# network_cal.py
# functions for network-calibration
#
#    Copyright (C) 2018 Andrew Chael
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.


from __future__ import division
from __future__ import print_function

from builtins import str
from builtins import range
from builtins import object

import numpy as np
import scipy.optimize as opt
import time
import copy
from multiprocessing import cpu_count, Pool

import ehtim.obsdata
import ehtim.parloop as parloop
from . import cal_helpers as calh
import ehtim.observing.obs_helpers as obsh
import ehtim.const_def as ehc

# Default for the `zbl_uvdist_max` argument of network_cal: the maximum
# uv-distance at which a baseline is still considered a "zero baseline".
ZBLCUTOFF = 1.e7
# Iteration cap handed to scipy.optimize.minimize via the 'maxiter' option.
MAXIT = 5000

###################################################################################################
# Network-Calibration
###################################################################################################


def network_cal(obs, zbl, sites=[], zbl_uvdist_max=ZBLCUTOFF, method="amp", minimizer_method='BFGS',
                pol='I', pad_amp=0., gain_tol=.2, solution_interval=0.0, scan_solutions=False,
                caltable=False, processes=-1, show_solution=False, debias=True, msgtype='bar'):
    """Network-calibrate a dataset with zero baseline constraints.

       The observation is split into solution intervals with obs.tlist(), each
       interval is solved independently by network_cal_scan(), and the results
       are merged into either a calibrated Obsdata or a Caltable.

       Args:
           obs (Obsdata): The observation to be calibrated
           zbl (float or function): constant zero baseline flux in Jy, or a function of UT hour.
           sites (list): list of sites to include in the network calibration.
                         empty list calibrates all sites
           zbl_uvdist_max (float): maximum uv-distance considered a zero baseline
           method (str): chooses what to calibrate, 'amp', 'phase', or 'both'.
           minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS').
                                   NOTE: only forwarded on the serial path; the
                                   multiprocessing workers are not passed this value.
           pol (str): which visibility to compute gains for
                      ('I', 'Q', 'U', 'V', 'RR', 'LL' or 'RRLL')

           pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
           gain_tol (float): gains that exceed this value will be disfavored by the prior
           solution_interval (float): solution interval in seconds;
                                      one gain is derived for each interval.
                                      If 0.0, a solution is determined for each unique time
           scan_solutions (bool): If True, determine one gain per site per scan.
                                  Supersedes solution_interval

           debias (bool): If True, debias the amplitudes
           caltable (bool): if True, returns a Caltable instead of an Obsdata
           processes (int): number of worker processes. A positive value uses that
                            many processes, 0 uses cpu_count() processes, and any
                            negative value disables multiprocessing. The count is
                            capped at the number of solution intervals.
           show_solution (bool): if True, display the solution as it is calculated
           msgtype (str): type of progress message to be printed, default is 'bar'

       Returns:
           (Obsdata): the calibrated observation, if caltable==False
           (Caltable): the derived calibration table, if caltable==True

       Raises:
           Exception: if pol is not supported or is inconsistent with obs.polrep
    """

    # Here, RRLL means to use both RR and LL (both as proxies for Stokes I)
    # to derive a network calibration solution
    if pol not in ['I', 'Q', 'U', 'V', 'RR', 'LL', 'RRLL']:
        raise Exception("Can only network-calibrate to I, Q, U, V, RR, LL, or RRLL!")
    if pol in ['I', 'Q', 'U', 'V']:
        if obs.polrep != 'stokes':
            raise Exception("netcal pol is a stokes parameter, but obs.polrep!='stokes'")
    elif obs.polrep != 'circ':
        raise Exception("netcal pol is RR or LL or RRLL, but obs.polrep!='circ'")

    # V = model visibility, V' = measured visibility, G_i = site gain
    # G_i * conj(G_j) * V_ij = V'_ij
    # (`sites` is only rebound here, never mutated, so the shared default list is safe)
    if len(sites) == 0:
        print("No stations specified in network cal: defaulting to calibrating all stations!")
        sites = obs.tarr['site']

    # find colocated sites
    cluster_data = calh.make_cluster_data(obs, zbl_uvdist_max)

    # split the observation into solution intervals
    scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions)
    nscans = len(scans)

    # Make the pool for parallel processing; `pool` stays None on the serial path
    # so that the clean-up below never touches an undefined name.
    pool = None
    if processes >= 0:
        counter = parloop.Counter(initval=0, maxval=nscans)
        if processes == 0:
            processes = int(cpu_count())
        processes = min(processes, nscans)
        print("Using Multiprocessing with %d Processes" % processes)
        pool = Pool(processes=processes, initializer=init, initargs=(counter,))
    else:
        print("Not Using Multiprocessing")

    # loop over scans and calibrate
    # Results are kept in a plain list: wrapping equal-length scans in an object
    # ndarray would produce a 2-D array that no longer concatenates to a data table.
    tstart = time.time()
    try:
        if pool is not None:  # with multiprocessing
            scans_cal = pool.map(get_network_scan_cal,
                                 [[i, nscans, scans[i],
                                   zbl, sites, cluster_data, obs.polrep, pol,
                                   method, pad_amp, gain_tol,
                                   caltable, show_solution, debias, msgtype]
                                  for i in range(nscans)])
        else:  # without multiprocessing
            scans_cal = []
            for i in range(nscans):
                obsh.prog_msg(i, nscans, msgtype=msgtype, nscan_last=i - 1)
                scans_cal.append(
                    network_cal_scan(scans[i], zbl, sites, cluster_data,
                                     polrep=obs.polrep, pol=pol,
                                     method=method, minimizer_method=minimizer_method,
                                     show_solution=show_solution, caltable=caltable,
                                     pad_amp=pad_amp, gain_tol=gain_tol, debias=debias))
    finally:
        # close multiprocessing jobs, also when a worker raised
        if pool is not None:
            pool.close()

    tstop = time.time()
    print("\nnetwork_cal time: %f s" % (tstop - tstart))

    if caltable:  # create and return a caltable
        allsites = obs.tarr['site']

        # seed with the first interval, then append later intervals site by site
        caldict = {k: v.reshape(1) for k, v in scans_cal[0].items()}
        for row in scans_cal[1:]:
            if len(row) == 0:
                continue

            for site in allsites:
                if site not in row:
                    continue
                dat = row[site]

                if site in caldict:
                    caldict[site] = np.append(caldict[site], dat)
                else:
                    # first appearance of this site: store a length-1 array,
                    # matching the entries seeded from the first interval
                    caldict[site] = dat.reshape(1)

        out = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr,
                                      source=obs.source, mjd=obs.mjd, timetype=obs.timetype)

    else:  # return the calibrated observation
        arglist, argdict = obs.obsdata_args()
        arglist[4] = np.concatenate(scans_cal)
        out = ehtim.obsdata.Obsdata(*arglist, **argdict)

    return out


def network_cal_scan(scan, zbl, sites, clustered_sites, polrep='stokes', pol='I',
                     zbl_uvidst_max=ZBLCUTOFF, method="both", minimizer_method='BFGS',
                     show_solution=False, pad_amp=0., gain_tol=.2, caltable=False, debias=True):
    """Network-calibrate a scan with zero baseline constraints.

       Solves jointly for one complex gain per calibrated site and one complex
       visibility per pair of site clusters seen in the scan, with baselines
       inside a cluster fixed to the zero-baseline flux.

       Args:
           scan (np.ndarray): structured array of visibility rows for one solution
                              interval; the fields 't1', 't2', 'time' and the
                              visibility/sigma fields of `polrep` are read.
                              NOTE: when caltable==False it is modified in place.
           zbl (float or function): constant zero baseline flux in Jy, or a function of UT hour.
                                    A function is evaluated on scan['time'] and the
                                    median of the result is used.
           sites (list): list of sites to include in the network calibration.
                         empty list calibrates all sites present in the scan
           clustered_sites (tuple): information on clustered sites, returned by make_cluster_data.
                                    Used as (list of clusters, dict site -> cluster number,
                                    list of cluster-number pairs stored as sets).

           polrep (str): 'stokes' or 'circ' to specify the polarization products in scan
           pol (str): which visibility to compute gains for; 'RRLL' fits RR and LL together
           zbl_uvidst_max (float): not used in this function. The name is misspelt
                                   ('uvidst'), but callers pass it by keyword.
           method (str): chooses what to calibrate, 'amp', 'phase', or 'both'
           minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
           pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
           gain_tol (float): gains that exceed this value will be disfavored by the prior

           debias (bool): If True, debias the amplitudes (only applied when method=='amp')
           caltable (bool): if True, returns a dict of gain corrections instead of the scan
           show_solution (bool): if True, display the solution as it is calculated


       Returns:
           (np.ndarray): the calibrated scan (same object as `scan`), if caltable==False
           (dict): site -> 0-d array of dtype ehc.DTCAL holding
                   (first scan time, R scale, L scale), if caltable==True
    """

    # determine the zero-baseline flux of the scan
    if callable(zbl):
        zbl_scan = np.median(zbl(scan['time']))
    else:
        zbl_scan = zbl

    # clustered site information
    allclusters = clustered_sites[0]
    clusterdict = clustered_sites[1]
    clusterbls = clustered_sites[2]

    # all the sites in the scan
    allsites = list(set(np.hstack((scan['t1'], scan['t2']))))

    if len(sites) == 0:
        print("No stations specified in network cal: defaulting to calibrating all !")
        sites = allsites

    # only include sites that are present
    sites = [s for s in sites if s in allsites]

    # create a dictionary to keep track of gains;
    # sites that aren't network calibrated (no co-located partners) get a value of -1
    # so that they won't be network calibrated; other sites get a unique number.
    # The -1 works as an array index: the gain vector gets a trailing 1. appended
    # below, so g[-1] is always the unit gain.
    tkey = {b: a for a, b in enumerate(sites)}
    for cluster in allclusters:
        if len(cluster) == 1:
            tkey[cluster[0]] = -1

    clusterkey = clusterdict

    # restrict solved cluster visibilities to ones present in the scan
    # (this is much faster than allowing many unconstrained variables)
    clusterbls_scan = [set([clusterkey[row['t1']], clusterkey[row['t2']]])
                       for row in scan
                       if len(set([clusterkey[row['t1']], clusterkey[row['t2']]])) == 2]

    # keep only the intercluster baselines that actually occur in this scan
    # (clusterbls_scan may hold repeats; clusterbls keeps one entry per pair)
    clusterbls = [cluster for cluster in clusterbls if cluster in clusterbls_scan]

    # make two lists of gain keys that relates scan bl gains to solved site ones
    # (-1 means that this station does not have a gain that is being solved for)
    # and make one list of scan keys that relates scan bl visibilities to solved cluster ones
    # (-1 means it's a zero baseline!)

    g1_keys = []
    g2_keys = []
    scan_keys = []
    for row in scan:
        try:
            g1_keys.append(tkey[row['t1']])
        except KeyError:
            g1_keys.append(-1)
        try:
            g2_keys.append(tkey[row['t2']])
        except KeyError:
            g2_keys.append(-1)

        clusternum1 = clusterkey[row['t1']]
        clusternum2 = clusterkey[row['t2']]

        if clusternum1 == clusternum2:  # sites are in the same cluster
            scan_keys.append(-1)
        else:  # sites are not in the same cluster
            bl_index = clusterbls.index(set((clusternum1, clusternum2)))
            scan_keys.append(bl_index)

    # no sites to calibrate on this scan!
    # (an early return is disabled here because it would not work with the caldict option)
    # if np.all(g1_keys == -1):
        # return scan

    # Start by restricting to visibilities that include baselines to a site with a zero-baseline
    vis_mask = [((row['t1'] in tkey.keys() and tkey[row['t1']] != -1)
                 or (row['t2'] in tkey.keys() and tkey[row['t2']] != -1)) for row in scan]

    # get scan visibilities of the specified polarization
    # (for 'RRLL' the RR rows are followed by the LL rows, so the mask is doubled too)
    if pol != 'RRLL':
        vis = scan[ehc.vis_poldict[pol]]
        sigma = scan[ehc.sig_poldict[pol]]
    else:
        vis = np.concatenate([scan[ehc.vis_poldict['RR']], scan[ehc.vis_poldict['LL']]])
        sigma = np.concatenate([scan[ehc.sig_poldict['RR']], scan[ehc.sig_poldict['LL']]])
        vis_mask = np.concatenate([vis_mask, vis_mask])

    # amplitude-only calibration fits |V|, optionally debiased
    if method == 'amp':
        if debias:
            vis = obsh.amp_debias(np.abs(vis), np.abs(sigma))
        else:
            vis = np.abs(vis)

    # inverse of the thermal sigma with the fractional pad_amp term added in quadrature
    sigma_inv = 1.0 / np.sqrt(sigma**2 + (pad_amp * np.abs(vis))**2)

    # initial guesses for parameters
    n_gains = len(sites)
    n_clusterbls = len(clusterbls)
    if show_solution:
        print('%d Gains; %d Clusters' % (n_gains, n_clusterbls))

    # Parameters are packed as float64 (re, im) pairs: first the gains (all
    # starting at 1), then the intercluster visibilities (each starting at a
    # measured, non-NaN visibility on that baseline when one exists).
    gpar_guess = np.ones(n_gains, dtype=np.complex128).view(dtype=np.float64)
    vpar_guess = np.ones(n_clusterbls, dtype=np.complex128)
    for i in range(len(scan_keys)):
        if scan_keys[i] < 0:
            continue
        if np.isnan(vis[i]):
            continue
        vpar_guess[scan_keys[i]] = vis[i]

    vpar_guess = vpar_guess.view(dtype=np.float64)
    gvpar_guess = np.hstack((gpar_guess, vpar_guess))

    # error function: data chi^2 plus a log-gain prior and a visibility penalty
    def errfunc(gvpar):

        # all the forward site gains (complex)
        g = gvpar[0:2 * n_gains].astype(np.float64).view(dtype=np.complex128)

        # all the intercluster visibilities (complex)
        v = gvpar[2 * n_gains:].astype(np.float64).view(dtype=np.complex128)

        # choose to only scale amplitudes or phases
        # ('amp' uses only the real part of each gain parameter)
        if method == "phase":
            g = g / np.abs(g)
        elif method == "amp":
            g = np.abs(np.real(g))

        # append the default values to g for missing points
        # and to v for the zero baseline points
        # (these land at index -1, which is what the -1 keys select)
        g = np.append(g, 1.)
        v = np.append(v, zbl_scan)

        # scan visibilities are either an intercluster visibility or the fixed zbl
        v_scan = v[scan_keys]
        g1 = g[g1_keys]
        g2 = g[g2_keys]
        if pol == 'RRLL':
            v_scan = np.concatenate([v_scan, v_scan])
            g1 = np.concatenate([g1, g1])
            g2 = np.concatenate([g2, g2])

        if method == 'amp':
            verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan)
        else:
            verr = vis - g1 * g2.conj() * v_scan

        # only finite residuals on baselines touching a calibrated site contribute
        chi   = np.abs(verr) * sigma_inv
        chisq = np.sum((chi * chi)[np.isfinite(chi) * vis_mask])

        # prior on the gains
        g_fracerr = gain_tol
        if method == "phase":
            chisq_g = 0 # because |g| == 1 so log(|g|) = 0
        elif method == "amp":
            logg    = np.log(g)
            chisq_g = np.sum(logg * logg) / (g_fracerr * g_fracerr)
        else:
            logabsg = np.log(np.abs(g))
            chisq_g = np.sum(logabsg * logabsg) / (g_fracerr * g_fracerr)

        # penalty on the visibility amplitudes: sum(|v|^4) / zbl_scan^4
        # (v here includes the appended zbl_scan entry)
        absv    = np.abs(v)
        vv      = absv * absv
        chisq_v = np.sum(vv * vv) / zbl_scan**4
        return chisq + chisq_g + chisq_v

    if np.max(g1_keys) > -1 or np.max(g2_keys) > -1:
        # run the minimizer to get a solution (but only run if there's at least one gain to fit)
        optdict = {'maxiter': MAXIT}  # minimizer params
        res = opt.minimize(errfunc, gvpar_guess, method=minimizer_method, options=optdict)

        # get solution
        g_fit = res.x[0:2 * n_gains].view(np.complex128)
        v_fit = res.x[2 * n_gains:].view(np.complex128)

        # apply the same amplitude/phase restriction that errfunc applied
        if method == "phase":
            g_fit = g_fit / np.abs(g_fit)
        if method == "amp":
            g_fit = np.abs(np.real(g_fit))

        if show_solution:
            print(np.abs(g_fit))
            print(np.abs(v_fit))
    else:
        # nothing to fit: every key is -1, so only the appended defaults are used
        g_fit = []
        v_fit = []

    g_fit = np.append(g_fit, 1.)
    v_fit = np.append(v_fit, zbl_scan)

    # Derive a calibration table or apply the solution to the scan
    if caltable:
        allsites = list(set(scan['t1']).union(set(scan['t2'])))

        caldict = {}
        for site in allsites:
            if site in sites:
                site_key = tkey[site]
            else:
                site_key = -1

            # We will *always* set the R and L gain corrections to be equal in network calibration,
            # to avoid breaking polarization consistency relationships
            rscale = g_fit[site_key]**-1
            lscale = g_fit[site_key]**-1

            # Note: we may want to give two entries for the start/stop times
            # when a non-zero solution interval is used
            caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL)

        out = caldict

    else:
        g1_fit = g_fit[g1_keys]
        g2_fit = g_fit[g2_keys]

        # per-baseline correction factor 1 / (G_1 * conj(G_2))
        gij_inv = (g1_fit * g2_fit.conj())**(-1)

        # the same correction is applied to every polarization product (in place)
        if polrep == 'stokes':
            # scale visibilities
            for vistype in ['vis', 'qvis', 'uvis', 'vvis']:
                scan[vistype] *= gij_inv
            # scale sigmas
            for sigtype in ['sigma', 'qsigma', 'usigma', 'vsigma']:
                scan[sigtype] *= np.abs(gij_inv)
        elif polrep == 'circ':
            # scale visibilities
            for vistype in ['rrvis', 'llvis', 'rlvis', 'lrvis']:
                scan[vistype] *= gij_inv
            # scale sigmas
            for sigtype in ['rrsigma', 'llsigma', 'rlsigma', 'lrsigma']:
                scan[sigtype] *= np.abs(gij_inv)

        out = scan

    return out


def init(x):
    """Store `x` as the module-level global `counter`.

    Passed to multiprocessing.Pool as the `initializer`, so every worker
    process gets the shared progress counter under that global name.

    Args:
        x: the object to publish as `counter`
    """
    globals()['counter'] = x


def get_network_scan_cal(args):
    """Unpack one packed argument sequence and forward it to get_network_scan_cal2.

    Pool.map hands each task a single object, so the per-scan arguments
    travel as one sequence and are spread into positional arguments here.

    Args:
        args (sequence): positional arguments for get_network_scan_cal2, in order

    Returns:
        whatever get_network_scan_cal2 returns for those arguments
    """
    unpacked = tuple(args)
    return get_network_scan_cal2(*unpacked)


def get_network_scan_cal2(i, n, scan, zbl, sites, cluster_data, polrep, pol,
                          method, pad_amp, gain_tol, caltable, show_solution, debias, msgtype,
                          minimizer_method='BFGS'):
    """Worker body for multiprocessing: report progress, then calibrate one scan.

       Args:
           i (int): index of this scan (not used in the body)
           n (int): total number of scans; progress is only reported when n > 1
           scan: the scan data passed on to network_cal_scan
           zbl, sites, cluster_data, polrep, pol, method, pad_amp, gain_tol,
           caltable, show_solution, debias: forwarded unchanged to network_cal_scan
           msgtype (str): type of progress message passed to obsh.prog_msg
           minimizer_method (str): Method for scipy.optimize.minimize, forwarded to
                                   network_cal_scan. Optional; the default matches
                                   network_cal_scan's own default, so existing
                                   15-argument callers behave exactly as before.

       Returns:
           the result of network_cal_scan for this scan
    """
    # `counter` is the shared progress counter installed in each worker by init()
    global counter

    if n > 1:
        counter.increment()
        obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1)

    return network_cal_scan(scan, zbl, sites, cluster_data, polrep=polrep, pol=pol,
                            zbl_uvidst_max=ZBLCUTOFF, method=method,
                            minimizer_method=minimizer_method, caltable=caltable,
                            show_solution=show_solution,
                            pad_amp=pad_amp, gain_tol=gain_tol, debias=debias)
